This R Markdown script analyses data from the PAL (probabilistic associative learning) task of the EMBA project. HGF parameters were extracted beforehand in MATLAB based on the subject-specific reaction times.
# ---- simulation and sampler settings ----
# number of posterior draws used for simulations / posterior predictive checks
nsim <- 250
# per-chain iterations and warmup shared by all brms models below
iter <- 3000
warm <- 1000
# fix the RNG seed so the whole analysis is reproducible
set.seed(2468)
The following packages are used in this RMarkdown file:
## [1] "R version 4.5.1 (2025-06-13)"
## [1] "knitr version 1.50"
## [1] "ggplot2 version 4.0.0"
## [1] "brms version 2.22.0"
## [1] "designr version 0.1.13"
## [1] "bridgesampling version 1.1.2"
## [1] "tidyverse version 2.0.0"
## [1] "ggpubr version 0.6.1"
## [1] "ggrain version 0.0.4"
## [1] "bayesplot version 1.13.0"
## [1] "SBC version 0.3.0.9000"
## [1] "rstatix version 0.7.2"
## [1] "easystats version 0.7.5"
## [1] "BayesFactor version 0.9.12.4.7"
## [1] "bayestestR version 0.17.0"
First, we load the parameters from the winning model.
# get HGF parameters of the winning model and attach the subject covariates
# (EDT and binary ADHD medication status); merge() joins on the column(s)
# shared by both tables (presumably subID -- verify against the CSV headers)
df.hgf = read_csv(file.path("HGF_results/main", "eHGF-L21_results.csv")) %>%
merge(., read_csv("../data/PAL-ADHD_data.csv", show_col_types = FALSE) %>%
select(subID, EDT, adhd.meds.bin) %>% distinct()) %>%
# character columns become factors (required for the contrast coding below)
mutate(across(where(is.character), as.factor))
# get belief state trajectories (trial-wise HGF quantities per subject)
df.trj = read_csv(file.path("HGF_results/main", "eHGF-L21_traj.csv"))
# extract the absolute changes in learning rate (alpha) between task phases;
# result: one row per subject x level (alpha2/alpha3) x change (pre2vol/vol2post)
df.upd = df.trj %>%
select(subID, diagnosis, trl, alpha2, alpha3) %>% ungroup() %>%
mutate(
# code the phases > only take the beginning and end of volatile
# NOTE: case_when() evaluates conditions in order, so the mapping is:
#   trl 1-72 -> "pre", trl 265+ -> "post", trl 73-144 -> "vol1",
#   trl 193-264 -> "vol2", and trl 145-192 -> NA (middle of the
#   volatile phase; removed by drop_na() below)
phase = case_when(
trl < 73 ~ "pre",
trl > 264 ~ "post",
trl < 145 ~ "vol1",
trl > 192 ~ "vol2"
)
) %>%
drop_na() %>%
# per-subject median learning rate within each phase
group_by(subID, diagnosis, phase) %>%
summarise(
alpha2 = median(alpha2),
alpha3 = median(alpha3)
) %>%
# one column per level x phase combination (e.g. alpha2_pre, alpha3_vol2)
pivot_wider(names_from = phase, id_cols = c(subID, diagnosis), values_from = starts_with("alpha")) %>%
group_by(subID, diagnosis) %>%
# absolute change going into (pre -> vol1) and out of (vol2 -> post) volatility
summarise(
alpha2_pre2vol = abs(alpha2_pre - alpha2_vol1),
alpha2_vol2post = abs(alpha2_post - alpha2_vol2),
alpha3_pre2vol = abs(alpha3_pre - alpha3_vol1),
alpha3_vol2post = abs(alpha3_post - alpha3_vol2)
) %>%
# long format; split e.g. "alpha2_pre2vol" into level and change factors
pivot_longer(cols = starts_with("alpha")) %>%
separate(name, into = c("level", "change")) %>%
# add the EDT covariate from the parameter table
merge(., df.hgf %>% select(subID, EDT)) %>%
mutate_if(is.character, as.factor)
# check whether log model evidence (LME) differs between diagnostic groups;
# first verify per-group normality with a Shapiro-Wilk test
df.hgf %>%
  group_by(diagnosis) %>%
  shapiro_test(LME) %>%
  kable() # all normally distributed
| diagnosis | variable | statistic | p |
|---|---|---|---|
| ADHD | LME | 0.9624505 | 0.5403480 |
| BOTH | LME | 0.9732667 | 0.7853028 |
| COMP | LME | 0.9722848 | 0.7633392 |
# Bayesian ANOVA of LME ~ diagnosis, cached on disk to avoid refitting
aov.path = file.path(brms_dir, "aov_lme.rds")
if (file.exists(aov.path)) {
  aov.lme = readRDS(aov.path)
} else {
  aov.lme = anovaBF(LME ~ diagnosis, data = df.hgf)
  saveRDS(aov.lme, aov.path)
}
# log Bayes factor (plus sampling error) for the diagnosis effect
aov.lme@bayesFactor
## bf error time code
## diagnosis -0.5840522 9.993788e-05 Thu Oct 30 11:19:46 2025 eca1873f721
There is anecdotal evidence against a difference in LME between diagnostic groups. This suggests that the eHGF model fit comparably well to the subjects of the different groups. Therefore, we move on to analyse its parameters.
The response model best fitting to our data was the one employed by Lawson et al. (2021): \[\log{RT} = \beta_0 + \beta_1 \times surprise_{stimulus} + \beta_2 \times pwPE + \beta_3 \times volatility_{phasic}\] Next, we use sum contrast coding for all of our categorical predictors.
# set and print the contrasts
# sum coding: the intercept becomes the grand mean and each contrast column
# compares one group against it (the last row is -1 across all columns)
contrasts(df.hgf$diagnosis) = contr.sum(3)
contrasts(df.hgf$diagnosis)
## [,1] [,2]
## ADHD 1 0
## BOTH 0 1
## COMP -1 -1
# contr.sum(2)[c(2,1)] flips the rows so that TRUE (medicated) is coded +1
contrasts(df.hgf$adhd.meds.bin) = contr.sum(2)[c(2,1)]
contrasts(df.hgf$adhd.meds.bin)
## [,1]
## FALSE -1
## TRUE 1
# row reordering swaps ADHD and BOTH so BOTH loads on the first contrast
# column (see printout below) -- this differs from df.hgf$diagnosis above
contrasts(df.upd$diagnosis) = contr.sum(3)[c(2,1,3),]
contrasts(df.upd$diagnosis)
## [,1] [,2]
## ADHD 0 1
## BOTH 1 0
## COMP -1 -1
# change: pre2vol = +1, vol2post = -1
contrasts(df.upd$change) = contr.sum(2)
contrasts(df.upd$change)
## [,1]
## pre2vol 1
## vol2post -1
# level: alpha2 = +1, alpha3 = -1
contrasts(df.upd$level) = contr.sum(2)
contrasts(df.upd$level)
## [,1]
## alpha2 1
## alpha3 -1
# model formula: predict the tonic volatility parameter om2 from diagnosis
f.om2 = brms::bf( om2 ~ diagnosis )
# set weakly informative priors
priors = c(
prior(normal(0, 4), class = Intercept),
prior(normal(0, 0.50), class = sigma),
prior(normal(0, 0.50), class = b)
)
# change Intercept based on empirical priors used in the HGF model:
# the prior string "normal(0, 4)" is rewritten in two sequential mutate()
# steps (duplicate `prior =` assignments are applied in order) --
# the first gsub() replaces the mean ("(0," -> "(<mean om2mu>, "), the
# second replaces the SD (" 4)" -> " <mean om2sa>)"); both regexes rely on
# the exact "normal(m, s)" formatting of the prior string
priors = priors %>%
mutate(
prior = if_else(
class == "Intercept",
gsub("\\(.*,", paste0("(", mean(df.hgf$om2mu), ", "), prior), prior),
prior = if_else(
class == "Intercept",
gsub(" .*\\)", paste0(" ", mean(df.hgf$om2sa), ")"), prior), prior)
)
kable(priors)
| prior | class | coef | group | resp | dpar | nlpar | lb | ub | source |
|---|---|---|---|---|---|---|---|---|---|
| normal(-6.921, 8.7788) | Intercept | NA | NA | user | |||||
| normal(0, 0.5) | sigma | NA | NA | user | |||||
| normal(0, 0.5) | b | NA | NA | user |
As the next step, we fit the model, check whether there are divergence or rhat issues, and then check whether the chains have converged.
# fit the final model; file= caches the fit on disk, save_pars(all = TRUE)
# keeps all draws so bridge sampling / BFs remain possible later;
# t is the thread count defined outside this chunk -- TODO confirm
m.om2 = brm(f.om2, seed = 2288,
df.hgf, prior = priors,
iter = iter, warmup = warm,
backend = "cmdstanr", threads = threading(t),
file = file.path(brms_dir, "m_hgf_om2"),
save_pars = save_pars(all = TRUE)
)
# NUTS sampler diagnostics: divergences, tree depth, E-BFMI
rstan::check_hmc_diagnostics(m.om2$fit)
##
## Divergences:
## 0 of 8000 iterations ended with a divergence.
##
## Tree depth:
## 0 of 8000 iterations saturated the maximum tree depth of 10.
##
## Energy:
## E-BFMI indicated no pathological behavior.
# check that rhats are below 1.01 (0 means every parameter converged);
# TRUE instead of T: T is a reassignable binding, TRUE is a reserved word
sum(brms::rhat(m.om2) >= 1.01, na.rm = TRUE)
## [1] 0
# check the trace plots of the population-level (b_) parameters
post.draws = as_draws_df(m.om2)
mcmc_trace(post.draws, regex_pars = "^b_",
facet_args = list(ncol = 3)) +
scale_x_continuous(breaks=scales::pretty_breaks(n = 3)) +
scale_y_continuous(breaks=scales::pretty_breaks(n = 3))
## Scale for x is already present.
## Adding another scale for x, which will replace the existing scale.
This model has no pathological behaviour with E-BFMI, no divergent sample and no rhats that are higher or equal to 1.01. Therefore, we go ahead and perform our posterior predictive checks.
# get posterior predictions (nsim draws from the posterior predictive)
post.pred = posterior_predict(m.om2, ndraws = nsim)
# check the fit of the predicted data compared to the real data (density overlay)
p1 = pp_check(m.om2, ndraws = nsim) +
theme_bw() + theme(legend.position = "none")
# distributions of means compared to the real values per diagnostic group
p2 = ppc_stat_grouped(df.hgf$om2, post.pred, df.hgf$diagnosis) +
theme_bw() + theme(legend.position = "none")
# stack both panels into one labelled figure
p = ggarrange(p1, p2,
nrow = 2, ncol = 1, labels = "AUTO")
annotate_figure(p, top = text_grob("Posterior predictive checks",
face = "bold", size = 14))
Similar to above, the data simulated from the model match the real data reasonably well, although they do not fully reproduce the overall shape of the distribution.
Now that we are convinced that we can trust our model, we have a look at its estimate and use the hypothesis function to assess our hypotheses and perform explorative tests.
# print a summary of the om2 ~ diagnosis model (coefficients, Rhat, ESS)
summary(m.om2)
## Family: gaussian
## Links: mu = identity; sigma = identity
## Formula: om2 ~ diagnosis
## Data: df.hgf (Number of observations: 66)
## Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
## total post-warmup draws = 8000
##
## Regression Coefficients:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept -5.93 0.28 -6.48 -5.37 1.00 8326 6174
## diagnosis1 0.19 0.30 -0.39 0.78 1.00 7098 6186
## diagnosis2 0.15 0.30 -0.45 0.73 1.00 7272 5704
##
## Further Distributional Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma 2.28 0.16 1.99 2.61 1.00 8008 5792
##
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# get the estimates and compute group comparisons; with sum coding the
# group means are Intercept +/- the contrast effects and COMP is the
# negative sum of the two diagnosis effects
df.m.om2 = post.draws %>%
select(starts_with("b_")) %>%
mutate(
ADHD = b_Intercept + b_diagnosis1,
BOTH = b_Intercept + b_diagnosis2,
COMP = b_Intercept - b_diagnosis1 - b_diagnosis2,
`h3c_ADHDvCOMP` = ADHD - COMP,
`e1_BOTHvCOMP` = BOTH - COMP,
`e2_ADHDvBOTH` = ADHD - BOTH
)
# plot the posterior distributions of the three group means; the dashed
# line marks the posterior mean of the grand intercept
# (the stray `fill = c_light` that was passed to ggplot() outside aes()
# was removed: unknown arguments in ggplot()'s ... have no effect)
df.m.om2 %>%
select(ADHD, BOTH, COMP) %>%
pivot_longer(cols = everything(), names_to = "coef", values_to = "estimate") %>%
ggplot(aes(x = estimate, y = coef)) +
geom_vline(xintercept = mean(df.m.om2$b_Intercept), linetype = 'dashed') +
ggdist::stat_halfeye(alpha = 0.7) + ylab(NULL) + theme_bw() +
theme(legend.position = "none")
# H3c: COMP != ADHD
# with this sum coding, ADHD - COMP = 2*diagnosis1 + diagnosis2
h3c = hypothesis(m.om2, "0 < 2*diagnosis1 + diagnosis2")
h3c$hypothesis
## Hypothesis Estimate Est.Error CI.Lower CI.Upper
## 1 (0)-(2*diagnosis1+diagnosis2) < 0 -0.5340246 0.5734074 -1.483274 0.4013223
## Evid.Ratio Post.Prob Star
## 1 4.602241 0.8215
# Explore BOTH: BOTH - COMP = diagnosis1 + 2*diagnosis2
e1 = hypothesis(m.om2, "0 < diagnosis1 + 2*diagnosis2", alpha = 0.025)
e1$hypothesis
## Hypothesis Estimate Est.Error CI.Lower CI.Upper
## 1 (0)-(diagnosis1+2*diagnosis2) < 0 -0.485513 0.5755258 -1.595381 0.632044
## Evid.Ratio Post.Prob Star
## 1 4.076142 0.803
# ADHD - BOTH = diagnosis1 - diagnosis2
e2 = hypothesis(m.om2, "diagnosis1 > diagnosis2", alpha= 0.025)
e2$hypothesis
## Hypothesis Estimate Est.Error CI.Lower CI.Upper
## 1 (diagnosis1)-(diagnosis2) > 0 0.04851162 0.4891632 -0.9173199 1.01821
## Evid.Ratio Post.Prob Star
## 1 1.145923 0.534
# equivalence: test each group comparison against the model's default ROPE
equivalence_test(df.m.om2 %>% select(starts_with("h") | starts_with("e")),
range = rope_range(m.om2))
## # Test for Practical Equivalence
##
## ROPE: [-0.26 0.26]
##
## Parameter | H0 | inside ROPE | 95% HDI
## -------------------------------------------------------
## h3c_ADHDvCOMP | Undecided | 24.05 % | [-0.59, 1.66]
## e1_BOTHvCOMP | Undecided | 25.34 % | [-0.63, 1.60]
## e2_ADHDvBOTH | Undecided | 42.87 % | [-0.92, 1.02]
# calculate effect sizes: each contrast is scaled by the pooled SD
# (sqrt of the summed variance components), i.e. a Cohen's-d-like measure
df.effect = post.draws %>%
mutate(across(starts_with("sd")|starts_with("sigma"), ~.^2)) %>%
mutate(
# `.` is the magrittr pronoun for the data frame piped into this mutate
# (i.e. after the variance squaring above)
sumvar = sqrt(rowSums(select(., starts_with("sd")|starts_with("sigma")))),
# contrast weights mirror the hypothesis() calls above
h3c = (2*`b_diagnosis1` + `b_diagnosis2`) / sumvar,
e1 = (`b_diagnosis1` + 2*`b_diagnosis2`) / sumvar,
e2 = -(-`b_diagnosis1` + `b_diagnosis2`) / sumvar
)
# summarise each standardized contrast with CI and verbal label
kable(df.effect %>% select(starts_with("e")|starts_with("h")) %>%
pivot_longer(cols = everything(), values_to = "estimate") %>%
group_by(name) %>%
summarise(
ci.lo = lower_ci(estimate),
mean = mean(estimate),
ci.hi = upper_ci(estimate),
interpret = interpret_cohens_d(mean)
), digits = 3
)
| name | ci.lo | mean | ci.hi | interpret |
|---|---|---|---|---|
| e1 | -0.279 | 0.214 | 0.700 | small |
| e2 | -0.407 | 0.021 | 0.448 | very small |
| h3c | -0.252 | 0.235 | 0.726 | small |
estimate = -0.53 [-1.48, 0.4], posterior probability = 82.15%
Predicting whether someone has ADHD or not based on the HGF parameters.
# recode the order and scale the predictors
# group:      COMP = 0, unmedicated ADHD/BOTH = 1, medicated = NA (excluded)
# group.meds: medicated = 1, unmedicated ADHD/BOTH = 0, unmedicated COMP = NA
df.hgf = df.hgf %>%
mutate(
group = case_when(
diagnosis == "COMP" ~ 0,
diagnosis != "COMP" & adhd.meds.bin == "FALSE" ~ 1,
TRUE ~ NA # TRUE instead of T (T is reassignable)
),
group.meds = if_else(adhd.meds.bin == "FALSE",
if_else(diagnosis == "COMP", NA, 0),
1)
# z-scale each HGF parameter into a new s-prefixed column (sbe1, som2, ...)
) %>% mutate(across(c(be1, be2, be3, ze, om2, om3), scale_this, .names = "s{.col}"))
kable(df.hgf %>% select(diagnosis, group, group.meds) %>% distinct(),
caption = "Coding for the order in the Bernoulli models")
| diagnosis | group | group.meds |
|---|---|---|
| BOTH | NA | 1 |
| ADHD | NA | 1 |
| ADHD | 1 | 0 |
| COMP | 0 | NA |
| BOTH | 1 | 0 |
# model formula: classify group (ADHD spectrum vs COMP) from the scaled
# HGF parameters
f = brms::bf( group ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3 )
f
## group ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3
# Bernoulli priors
priors.bern = c(
prior(normal(0.50, 0.50), class = Intercept), # roughly 1:1
prior(normal(0, 1.00), class = b)
)
# fit the final model (cached via file=)
# NOTE(review): threads = threading(8) is hard-coded here while the other
# models use threading(t) -- confirm this is intentional
m = brm(f,
df.hgf, prior = priors.bern,
family = bernoulli(link = "logit"),
iter = iter, warmup = warm,
backend = "cmdstanr", threads = threading(8),
file = file.path(brms_dir, "m_hgf_bern_adhd"),
seed = 4858
)
# NUTS sampler diagnostics: divergences, tree depth, E-BFMI
rstan::check_hmc_diagnostics(m$fit)
##
## Divergences:
## 0 of 8000 iterations ended with a divergence.
##
## Tree depth:
## 0 of 8000 iterations saturated the maximum tree depth of 10.
##
## Energy:
## E-BFMI indicated no pathological behavior.
# check that rhats are below 1.01 (0 means every parameter converged);
# TRUE instead of T: T is a reassignable binding, TRUE is a reserved word
sum(brms::rhat(m) >= 1.01, na.rm = TRUE)
## [1] 0
# check the trace plots of the population-level (b_) parameters
post.draws = as_draws_df(m)
mcmc_trace(post.draws, regex_pars = "^b_",
facet_args = list(ncol = 4)) +
scale_x_continuous(breaks=scales::pretty_breaks(n = 3)) +
scale_y_continuous(breaks=scales::pretty_breaks(n = 3))
## Scale for x is already present.
## Adding another scale for x, which will replace the existing scale.
This model has no pathological behaviour with E-BFMI, no divergent sample and no rhats that are higher or equal to 1.01. Therefore, we go ahead and perform our posterior predictive checks.
# get posterior predictions (nsim draws from the posterior predictive)
post.pred = posterior_predict(m, ndraws = nsim)
# check the fit of the predicted data compared to the real data;
# rows with group == NA (medicated subjects) were dropped by brm(), so
# subset the observed outcome accordingly
p = ppc_bars(df.hgf[!is.na(df.hgf$group),]$group, post.pred) +
theme_bw() + theme(legend.position = "none")
annotate_figure(p, top = text_grob("Posterior predictive checks",
face = "bold", size = 14))
The overall simulated data fits reasonably well. Now that we are convinced that we can trust our model, we have a look at its estimates.
# print a summary of the Bernoulli classification model (log-odds scale)
summary(m)
## Family: bernoulli
## Links: mu = logit
## Formula: group ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3
## Data: df.hgf (Number of observations: 41)
## Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
## total post-warmup draws = 8000
##
## Regression Coefficients:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept 0.06 0.29 -0.52 0.63 1.00 11663 5569
## sbe1 0.18 0.36 -0.53 0.91 1.00 11572 5518
## sbe2 -0.09 0.37 -0.81 0.65 1.00 10220 6430
## sbe3 -0.19 0.36 -0.89 0.51 1.00 12722 6120
## sze 0.33 0.42 -0.48 1.17 1.00 10045 5970
## som2 0.66 0.36 -0.01 1.40 1.00 11259 6381
## som3 -0.04 0.44 -0.89 0.82 1.00 9145 6332
##
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# plot the posterior distributions of the predictors (intercept excluded);
# a coefficient is flagged "credible" when its 95% quantile interval
# excludes zero
post.draws %>%
select(starts_with("b_") & !starts_with("b_Int")) %>%
pivot_longer(cols = starts_with("b_"), names_to = "coef", values_to = "estimate") %>%
mutate(
coef = substr(coef, 3, nchar(coef)), # strip the "b_" prefix
coef = fct_reorder(coef, desc(estimate))
) %>%
group_by(coef) %>%
mutate(
cred = case_when(
(mean(estimate) < 0 & quantile(estimate, probs = 0.975) < 0) |
(mean(estimate) > 0 & quantile(estimate, probs = 0.025) > 0) ~ "credible",
TRUE ~ "not credible" # TRUE instead of T (T is reassignable)
)
) %>% ungroup() %>%
ggplot(aes(x = estimate, y = coef, fill = cred)) +
geom_vline(xintercept = 0, linetype = 'dashed') +
ggdist::stat_halfeye(alpha = 0.7) + ylab(NULL) +
scale_fill_manual(values = c("credible" = c_dark, "not credible" = c_light)) +
theme_bw() + theme(legend.position = "bottom", legend.direction = "horizontal")
# explore: is the som2 effect positive? (one-sided, 95% two-sided CI)
e1 = hypothesis(m, "0 > -som2", alpha = 0.025)
e1$hypothesis
## Hypothesis Estimate Est.Error CI.Lower CI.Upper Evid.Ratio Post.Prob
## 1 (0)-(-som2) > 0 0.6647926 0.3619556 -0.00997184 1.401287 36.91469 0.973625
## Star
## 1
# test all coefficients against the model's default ROPE
equivalence_test(m)
## # Test for Practical Equivalence
##
## ROPE: [-0.18 0.18]
##
## Parameter | H0 | inside ROPE | 95% HDI
## -------------------------------------------------------
## Intercept | Undecided | 47.91 % | [-0.52, 0.63]
## sbe1 | Undecided | 36.29 % | [-0.53, 0.91]
## sbe2 | Undecided | 37.88 % | [-0.81, 0.65]
## sbe3 | Undecided | 36.39 % | [-0.89, 0.51]
## sze | Undecided | 27.62 % | [-0.48, 1.17]
## som2 | Undecided | 6.51 % | [-9.97e-03, 1.40]
## som3 | Undecided | 33.36 % | [-0.89, 0.82]
# effect sizes of the scaled predictors
# NOTE(review): interpret_cohens_d() is applied to logistic-regression
# log-odds coefficients here, not to Cohen's d values -- the verbal labels
# are heuristic; confirm this is intended
kable(post.draws %>% select(starts_with("b_s")) %>%
pivot_longer(cols = everything(), values_to = "estimate") %>%
group_by(name) %>%
summarise(
ci.lo = lower_ci(estimate),
mean = mean(estimate),
ci.hi = upper_ci(estimate),
interpret = interpret_cohens_d(mean)
), digits = 3
)
| name | ci.lo | mean | ci.hi | interpret |
|---|---|---|---|---|
| b_sbe1 | -0.533 | 0.182 | 0.914 | very small |
| b_sbe2 | -0.811 | -0.093 | 0.650 | very small |
| b_sbe3 | -0.891 | -0.186 | 0.507 | very small |
| b_som2 | -0.010 | 0.665 | 1.401 | medium |
| b_som3 | -0.892 | -0.037 | 0.821 | very small |
| b_sze | -0.477 | 0.329 | 1.174 | small |
Predicting whether someone with ADHD is taking medication or not based on the HGF parameters.
# model formula: classify medication status within the clinical groups
# (group.meds: 1 = medicated, 0 = unmedicated ADHD/BOTH, COMP excluded as NA)
f = brms::bf( group.meds ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3 )
f
## group.meds ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3
# fit the final model, reusing priors.bern from the previous Bernoulli model
m = brm(f,
df.hgf, prior = priors.bern,
family = bernoulli(link = "logit"),
iter = iter, warmup = warm,
backend = "cmdstanr", threads = threading(8),
file = file.path(brms_dir, "m_hgf_bern_meds"),
seed = 8428
)
# NUTS sampler diagnostics: divergences, tree depth, E-BFMI
rstan::check_hmc_diagnostics(m$fit)
##
## Divergences:
## 0 of 8000 iterations ended with a divergence.
##
## Tree depth:
## 0 of 8000 iterations saturated the maximum tree depth of 10.
##
## Energy:
## E-BFMI indicated no pathological behavior.
# check that rhats are below 1.01 (0 means every parameter converged);
# TRUE instead of T: T is a reassignable binding, TRUE is a reserved word
sum(brms::rhat(m) >= 1.01, na.rm = TRUE)
## [1] 0
# check the trace plots of the population-level (b_) parameters
post.draws = as_draws_df(m)
mcmc_trace(post.draws, regex_pars = "^b_",
facet_args = list(ncol = 4)) +
scale_x_continuous(breaks=scales::pretty_breaks(n = 3)) +
scale_y_continuous(breaks=scales::pretty_breaks(n = 3))
## Scale for x is already present.
## Adding another scale for x, which will replace the existing scale.
This model has no pathological behaviour with E-BFMI, no divergent sample and no rhats that are higher or equal to 1.01. Therefore, we go ahead and perform our posterior predictive checks.
# get posterior predictions (nsim draws from the posterior predictive)
post.pred = posterior_predict(m, ndraws = nsim)
# check the fit of the predicted data compared to the real data;
# rows with group.meds == NA (unmedicated COMP) were dropped by brm()
p = ppc_bars(df.hgf[!is.na(df.hgf$group.meds),]$group.meds, post.pred) +
theme_bw() + theme(legend.position = "none")
annotate_figure(p, top = text_grob("Posterior predictive checks",
face = "bold", size = 14))
The overall simulated data fits reasonably well. Now that we are convinced that we can trust our model, we have a look at its estimates.
# print a summary of the medication classification model (log-odds scale)
summary(m)
## Family: bernoulli
## Links: mu = logit
## Formula: group.meds ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3
## Data: df.hgf (Number of observations: 44)
## Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
## total post-warmup draws = 8000
##
## Regression Coefficients:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept 0.39 0.29 -0.17 0.96 1.00 13857 5726
## sbe1 0.19 0.35 -0.49 0.90 1.00 11114 6722
## sbe2 0.28 0.36 -0.43 1.00 1.00 9999 6075
## sbe3 0.04 0.33 -0.60 0.69 1.00 11532 6472
## sze 0.20 0.35 -0.46 0.91 1.00 11852 6181
## som2 -0.52 0.33 -1.20 0.11 1.00 10907 6430
## som3 -0.18 0.38 -0.94 0.56 1.00 11295 5957
##
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# plot the posterior distributions of the predictors (intercept excluded);
# a coefficient is flagged "credible" when its 95% quantile interval
# excludes zero
post.draws %>%
select(starts_with("b_") & !starts_with("b_Int")) %>%
pivot_longer(cols = starts_with("b_"), names_to = "coef", values_to = "estimate") %>%
mutate(
coef = substr(coef, 3, nchar(coef)), # strip the "b_" prefix
coef = fct_reorder(coef, desc(estimate))
) %>%
group_by(coef) %>%
mutate(
cred = case_when(
(mean(estimate) < 0 & quantile(estimate, probs = 0.975) < 0) |
(mean(estimate) > 0 & quantile(estimate, probs = 0.025) > 0) ~ "credible",
TRUE ~ "not credible" # TRUE instead of T (T is reassignable)
)
) %>% ungroup() %>%
ggplot(aes(x = estimate, y = coef, fill = cred)) +
geom_vline(xintercept = 0, linetype = 'dashed') +
ggdist::stat_halfeye(alpha = 0.7) + ylab(NULL) +
scale_fill_manual(values = c("credible" = c_dark, "not credible" = c_light)) +
theme_bw() + theme(legend.position = "bottom", legend.direction = "horizontal")
# explore: is the som2 effect negative? (one-sided, 95% two-sided CI)
e1 = hypothesis(m, "0 > som2", alpha = 0.025)
e1$hypothesis
## Hypothesis Estimate Est.Error CI.Lower CI.Upper Evid.Ratio Post.Prob
## 1 (0)-(som2) > 0 0.522949 0.3330175 -0.1088378 1.203256 18.95012 0.949875
## Star
## 1
# test all coefficients against the model's default ROPE
equivalence_test(m)
## # Test for Practical Equivalence
##
## ROPE: [-0.18 0.18]
##
## Parameter | H0 | inside ROPE | 95% HDI
## ---------------------------------------------------
## Intercept | Undecided | 22.33 % | [-0.17, 0.96]
## sbe1 | Undecided | 36.91 % | [-0.49, 0.90]
## sbe2 | Undecided | 31.49 % | [-0.43, 1.00]
## sbe3 | Undecided | 44.79 % | [-0.60, 0.69]
## sze | Undecided | 37.37 % | [-0.46, 0.91]
## som2 | Undecided | 13.16 % | [-1.20, 0.11]
## som3 | Undecided | 35.99 % | [-0.94, 0.56]
# effect sizes of the scaled predictors
# NOTE(review): interpret_cohens_d() is applied to logistic-regression
# log-odds coefficients here, not to Cohen's d values -- the verbal labels
# are heuristic; confirm this is intended
kable(post.draws %>% select(starts_with("b_s")) %>%
pivot_longer(cols = everything(), values_to = "estimate") %>%
group_by(name) %>%
summarise(
ci.lo = lower_ci(estimate),
mean = mean(estimate),
ci.hi = upper_ci(estimate),
interpret = interpret_cohens_d(mean)
), digits = 3
)
| name | ci.lo | mean | ci.hi | interpret |
|---|---|---|---|---|
| b_sbe1 | -0.494 | 0.185 | 0.899 | very small |
| b_sbe2 | -0.430 | 0.276 | 0.998 | small |
| b_sbe3 | -0.604 | 0.039 | 0.691 | very small |
| b_som2 | -1.203 | -0.523 | 0.109 | medium |
| b_som3 | -0.938 | -0.180 | 0.561 | very small |
| b_sze | -0.461 | 0.202 | 0.907 | small |
# raincloud figure of the participant-specific HGF parameters per group
p = df.hgf %>%
# relabel BOTH for the figure legend
mutate(diagnosis = if_else(diagnosis == "BOTH", "ADHD+ASD", diagnosis)) %>%
select(subID, diagnosis, be1, be2, be3, ze, om2, om3) %>% #
pivot_longer(cols = c(be1, be2, be3, ze, om2, om3),
names_to = "parameter") %>%
mutate(
# human-readable facet labels, ordered via the factor levels below
parameter = factor(case_match(parameter,
"be1" ~ "stimulus surprise",
"be2" ~ "precision-weighted PE",
"be3" ~ "phasic volatility",
"ze" ~ "Sigma (decision noise)",
"om2" ~ "cue-outcome tonic volatility",
"om3" ~ "environmental tonic volatility"
), levels = c("cue-outcome tonic volatility",
"environmental tonic volatility",
"stimulus surprise",
"precision-weighted PE",
"phasic volatility",
"Sigma (decision noise)"))
) %>%
ggplot(aes(x = 1, y = value, fill = diagnosis, colour = diagnosis)) + #
# right-sided raincloud: dodged boxplots, nudged violins, jittered points
geom_rain(rain.side = 'r',
boxplot.args = list(color = "black", outlier.shape = NA, show.legend = FALSE, alpha = .8),
violin.args = list(color = "black", outlier.shape = NA, alpha = .8),
boxplot.args.pos = list(
position = ggpp::position_dodgenudge(x = 0, width = 0.3), width = 0.3
),
point.args = list(show.legend = FALSE, alpha = .5),
violin.args.pos = list(
width = 0.6, position = position_nudge(x = 0.16)),
point.args.pos = list(position = ggpp::position_dodgenudge(x = -0.25, width = 0.1))) +
scale_fill_manual(values = col.grp) +
scale_color_manual(values = col.grp) +
facet_wrap(. ~ parameter, scales = "free", ncol = 3) +
theme_bw() +
theme(legend.position = "bottom", plot.title = element_blank(),
axis.title.y = element_blank(), axis.title.x = element_blank(),
text = element_text(size = 13), axis.text.x=element_blank(),
axis.ticks.x=element_blank(), legend.direction = "horizontal",
legend.title = element_blank(),
legend.margin=margin(0,0,0,0),
legend.box.margin=margin(-5,0,0,0))
# add the title and export the figure for the manuscript
p.a = annotate_figure(p, top = text_grob("Participant-specific HGF parameters",
face = "bold", size = 14))
ggsave("plots/FigHGF.svg", plot = p.a, units = "cm", width = 27, height = 13.5)
# include medication: same raincloud figure, but subgroups split by
# medication status (e.g. "ADHDmedicated" vs "ADHD")
df.hgf %>%
mutate(diagnosis = if_else(diagnosis == "BOTH", "ADHD+ASD", diagnosis),
adhd.meds.bin = case_when(adhd.meds.bin == "TRUE" ~ "medicated",
TRUE ~ ""), # TRUE instead of T (T is reassignable)
group = paste0(diagnosis, adhd.meds.bin)) %>%
select(subID, diagnosis, group, be1, be2, be3, ze, om2, om3) %>% #
pivot_longer(cols = c(be1, be2, be3, ze, om2, om3),
names_to = "parameter") %>%
mutate(
# human-readable facet labels, ordered via the factor levels below
parameter = factor(case_match(parameter,
"be1" ~ "stimulus surprise",
"be2" ~ "precision-weighted PE",
"be3" ~ "phasic volatility",
"ze" ~ "Sigma (decision noise)",
"om2" ~ "2nd tonic volatility",
"om3" ~ "3rd tonic volatility"
), levels = c("2nd tonic volatility",
"3rd tonic volatility",
"stimulus surprise",
"precision-weighted PE",
"phasic volatility",
"Sigma (decision noise)"))
) %>%
ggplot(aes(x = diagnosis, y = value, fill = group, colour = group)) + #
geom_rain(rain.side = 'r',
boxplot.args = list(color = "black", outlier.shape = NA, show.legend = FALSE, alpha = .8),
violin.args = list(color = "black", outlier.shape = NA, alpha = .8),
boxplot.args.pos = list(
position = ggpp::position_dodgenudge(x = 0, width = 0.3), width = 0.3
),
point.args = list(show.legend = FALSE, alpha = .5),
violin.args.pos = list(
width = 0.6, position = position_nudge(x = 0.16)),
point.args.pos = list(position = ggpp::position_dodgenudge(x = -0.25, width = 0.1))) +
#scale_fill_manual(values = col.grp) +
#scale_color_manual(values = col.grp) +
facet_wrap(. ~ parameter, scales = "free", ncol = 3) +
theme_bw() +
theme(legend.position = "bottom", plot.title = element_blank(),
axis.title.y = element_blank(), axis.title.x = element_blank(),
text = element_text(size = 13), axis.text.x=element_blank(),
axis.ticks.x=element_blank(), legend.direction = "horizontal",
legend.title = element_blank(),
legend.margin=margin(0,0,0,0),
legend.box.margin=margin(-5,0,0,0))
# model formula: learning-rate change as a function of diagnosis, HGF level
# and change direction, with by-subject random slopes
f.alpha = brms::bf( value ~ diagnosis * level * change + (level + change | subID) )
# set weakly informative priors taking Lawson 2017 into consideration
# NOTE: this overwrites the `priors` object used for the om2 model above
priors = c(
prior(normal(-5, 2), class = Intercept),
prior(normal(0.5, 0.5), class = sigma),
prior(normal(0.5, 0.5), class = sd),
prior(lkj(2), class = cor),
prior(normal(0, 1.0), class = b) # probably big difference between levels
)
As the next step, we fit the model, check whether there are divergence or rhat issues, and then check whether the chains have converged.
# fit the final model; lognormal family since the alpha changes are
# positive and right-skewed; cached via file=, all pars saved for BFs;
# t is the thread count defined outside this chunk -- TODO confirm
m.alpha = brm(f.alpha, family = lognormal,
df.upd, prior = priors, seed = 6688,
iter = iter, warmup = warm,
backend = "cmdstanr", threads = threading(t),
file = file.path(brms_dir, "m_hgf_alpha"),
save_pars = save_pars(all = TRUE)
)
# NUTS sampler diagnostics: divergences, tree depth, E-BFMI
rstan::check_hmc_diagnostics(m.alpha$fit)
##
## Divergences:
## 0 of 8000 iterations ended with a divergence.
##
## Tree depth:
## 0 of 8000 iterations saturated the maximum tree depth of 10.
##
## Energy:
## E-BFMI indicated no pathological behavior.
# check that rhats are below 1.01 (0 means every parameter converged);
# TRUE instead of T: T is a reassignable binding, TRUE is a reserved word
sum(brms::rhat(m.alpha) >= 1.01, na.rm = TRUE)
## [1] 0
# check the trace plots of the population-level (b_) parameters
post.draws = as_draws_df(m.alpha)
mcmc_trace(post.draws, regex_pars = "^b_",
facet_args = list(ncol = 3)) +
scale_x_continuous(breaks=scales::pretty_breaks(n = 3)) +
scale_y_continuous(breaks=scales::pretty_breaks(n = 3))
## Scale for x is already present.
## Adding another scale for x, which will replace the existing scale.
This model has no pathological behaviour with E-BFMI, no divergent sample and no rhats that are higher or equal to 1.01. Therefore, we go ahead and perform our posterior predictive checks.
# get posterior predictions (nsim draws from the posterior predictive)
post.pred = posterior_predict(m.alpha, ndraws = nsim)
# check the fit of the predicted data compared to the real data;
# the x-axis is clipped to [0, 0.10] because the distribution is skewed
p1 = pp_check(m.alpha, ndraws = nsim) +
theme_bw() + theme(legend.position = "none") + xlim(0, 0.10)
# distributions of means compared to the real values per grouping factor
p2 = ppc_stat_grouped(df.upd$value, post.pred, df.upd$diagnosis) +
theme_bw() + theme(legend.position = "none")
p3 = ppc_stat_grouped(df.upd$value, post.pred, df.upd$level) +
theme_bw() + theme(legend.position = "none")
p4 = ppc_stat_grouped(df.upd$value, post.pred, df.upd$change) +
theme_bw() + theme(legend.position = "none")
# stack all panels into one titled figure
p = ggarrange(p1, p2, p3, p4, ncol = 1)
annotate_figure(p, top = text_grob("Posterior predictive checks",
face = "bold", size = 14))
This model fits the data well enough.
Now that we are convinced that we can trust our model, we have a look at its estimate and use the hypothesis function to assess our hypotheses and perform explorative tests.
# print a summary of the learning-rate-change model (log scale)
summary(m.alpha)
## Family: lognormal
## Links: mu = identity; sigma = identity
## Formula: value ~ diagnosis * level * change + (level + change | subID)
## Data: df.upd (Number of observations: 264)
## Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
## total post-warmup draws = 8000
##
## Multilevel Hyperparameters:
## ~subID (Number of levels: 66)
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## sd(Intercept) 1.11 0.12 0.90 1.35 1.00 2373
## sd(level1) 0.82 0.10 0.63 1.03 1.00 2986
## sd(change1) 0.21 0.08 0.05 0.36 1.00 2793
## cor(Intercept,level1) 0.41 0.13 0.15 0.65 1.00 2399
## cor(Intercept,change1) 0.60 0.23 0.02 0.92 1.00 5210
## cor(level1,change1) 0.62 0.23 0.03 0.92 1.00 5018
## Tail_ESS
## sd(Intercept) 4179
## sd(level1) 5020
## sd(change1) 1653
## cor(Intercept,level1) 4229
## cor(Intercept,change1) 3847
## cor(level1,change1) 3687
##
## Regression Coefficients:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## Intercept -4.95 0.15 -5.24 -4.66 1.00 1627
## diagnosis1 0.18 0.21 -0.22 0.59 1.00 1712
## diagnosis2 -0.07 0.21 -0.48 0.33 1.00 1640
## level1 0.33 0.12 0.10 0.56 1.00 2420
## change1 0.77 0.07 0.64 0.90 1.00 6477
## diagnosis1:level1 0.06 0.17 -0.26 0.39 1.00 2513
## diagnosis2:level1 0.15 0.16 -0.17 0.47 1.00 2632
## diagnosis1:change1 0.05 0.10 -0.14 0.24 1.00 5609
## diagnosis2:change1 0.04 0.10 -0.15 0.23 1.00 6322
## level1:change1 -0.10 0.06 -0.22 0.03 1.00 12332
## diagnosis1:level1:change1 0.06 0.09 -0.11 0.24 1.00 7030
## diagnosis2:level1:change1 -0.07 0.09 -0.24 0.10 1.00 7296
## Tail_ESS
## Intercept 2760
## diagnosis1 2534
## diagnosis2 2841
## level1 3706
## change1 5573
## diagnosis1:level1 4210
## diagnosis2:level1 4199
## diagnosis1:change1 5829
## diagnosis2:change1 5894
## level1:change1 5555
## diagnosis1:level1:change1 6185
## diagnosis2:level1:change1 5620
##
## Further Distributional Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma 1.01 0.07 0.89 1.15 1.00 2726 4036
##
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# get the estimates and compute group comparisons
df.m.alpha = post.draws %>%
select(starts_with("b_"))
# plot the posterior distributions (intercept excluded); contrast labels
# are translated to the group/level/phase names they encode; a coefficient
# is flagged "credible" when its 95% quantile interval excludes zero
df.m.alpha %>%
select(starts_with("b_")) %>%
pivot_longer(cols = starts_with("b_"), names_to = "coef", values_to = "estimate") %>%
subset(!startsWith(coef, "b_Int")) %>%
mutate(
coef = substr(coef, 3, nchar(coef)), # strip the "b_" prefix
coef = str_replace_all(coef, ":", " x "),
coef = str_replace_all(coef, "diagnosis1", "ADHD"),
coef = str_replace_all(coef, "diagnosis2", "BOTH"),
coef = str_replace_all(coef, "level1", "alpha2"),
coef = str_replace_all(coef, "change1", "pre2vol"),
coef = fct_reorder(coef, desc(estimate))
) %>%
group_by(coef) %>%
mutate(
cred = case_when(
(mean(estimate) < 0 & quantile(estimate, probs = 0.975) < 0) |
(mean(estimate) > 0 & quantile(estimate, probs = 0.025) > 0) ~ "credible",
TRUE ~ "not credible" # TRUE instead of T (T is reassignable)
)
) %>% ungroup() %>%
ggplot(aes(x = estimate, y = coef, fill = cred)) +
geom_vline(xintercept = 0, linetype = 'dashed') +
ggdist::stat_halfeye(alpha = 0.7) + ylab(NULL) + theme_bw() +
scale_fill_manual(values = c(c_dark, c_light)) + theme(legend.position = "none")
# get the design matrix to figure out how to set the contrasts:
# differencing the mean design-matrix rows of two cells yields the
# coefficient weights for that comparison (used for the hypotheses below)
df.des = cbind(df.upd,
model.matrix(~ diagnosis * level * change, data = df.upd)) %>%
ungroup() %>%
select(-subID, -value) %>% distinct()
# H4c ADHD != COMP (marginal over level and change)
# NOTE: in df.upd's contrast coding diagnosis2 is ADHD (see contrasts above)
h4c = hypothesis(m.alpha, "0 < 2*diagnosis1 + diagnosis2", alpha = 0.025)
h4c$hypothesis
## Hypothesis Estimate Est.Error CI.Lower CI.Upper
## 1 (0)-(2*diagnosis1+diagnosis2) < 0 -0.2903084 0.3674689 -1.013658 0.4192193
## Evid.Ratio Post.Prob Star
## 1 3.692082 0.786875
# Exploration: alpha3 ADHD != COMP
# derive the contrast weights: difference of the mean design-matrix rows
# for COMP minus ADHD within the alpha3 level
t(df.des %>%
filter(level == "alpha3" & diagnosis != "BOTH") %>%
group_by(diagnosis) %>%
summarise(across(where(is.numeric), ~ mean(.x))) %>%
arrange(diagnosis) %>%
select(where(is.numeric)) %>%
map_df(~ diff(.x))) # COMP - ADHD
## [,1]
## EDT -0.06595082
## (Intercept) 0.00000000
## diagnosis1 -1.00000000
## diagnosis2 -2.00000000
## level1 0.00000000
## change1 0.00000000
## diagnosis1:level1 1.00000000
## diagnosis2:level1 2.00000000
## diagnosis1:change1 0.00000000
## diagnosis2:change1 0.00000000
## level1:change1 0.00000000
## diagnosis1:level1:change1 0.00000000
## diagnosis2:level1:change1 0.00000000
# one-sided test of the alpha3 ADHD-vs-COMP difference, using the weights
# obtained from the design-matrix difference computed just above
e1 = hypothesis(m.alpha, "0 > -2*diagnosis1 - diagnosis2 +
2*diagnosis1:level1 + diagnosis2:level1", alpha = 0.025)
e1$hypothesis
## Hypothesis
## 1 (0)-(-2*diagnosis1-diagnosis2+2*diagnosis1:level1+diagnosis2:level1) > 0
## Estimate Est.Error CI.Lower CI.Upper Evid.Ratio Post.Prob Star
## 1 0.00919434 0.3906923 -0.7748123 0.7637141 1.060793 0.51475
# H4c: alpha2 ADHD != COMP
# same procedure as for alpha3: difference the group-averaged design-matrix
# rows (COMP - ADHD) at level alpha2 to obtain the weights for hypothesis e2;
# note the sign flip on the diagnosis:level interaction terms vs. alpha3
t(df.des %>%
filter(level == "alpha2" & diagnosis != "BOTH") %>%
group_by(diagnosis) %>%
summarise(across(where(is.numeric), ~ mean(.x))) %>%
arrange(diagnosis) %>%
select(where(is.numeric)) %>%
map_df(~ diff(.x))) # COMP - ADHD
## [,1]
## EDT -0.06595082
## (Intercept) 0.00000000
## diagnosis1 -1.00000000
## diagnosis2 -2.00000000
## level1 0.00000000
## change1 0.00000000
## diagnosis1:level1 -1.00000000
## diagnosis2:level1 -2.00000000
## diagnosis1:change1 0.00000000
## diagnosis2:change1 0.00000000
## level1:change1 0.00000000
## diagnosis1:level1:change1 0.00000000
## diagnosis2:level1:change1 0.00000000
# one-sided test of the alpha2 ADHD-vs-COMP difference with the weights from
# the design-matrix difference directly above
e2 = hypothesis(m.alpha, "0 > -(2*diagnosis1 + diagnosis2 +
2*diagnosis1:level1 + diagnosis2:level1)", alpha = 0.025)
e2$hypothesis
## Hypothesis
## 1 (0)-(-(2*diagnosis1+diagnosis2+2*diagnosis1:level1+diagnosis2:level1)) > 0
## Estimate Est.Error CI.Lower CI.Upper Evid.Ratio Post.Prob Star
## 1 0.5714225 0.5394324 -0.4726952 1.62411 6.054674 0.85825
# Explore BOTH
# analogous exploratory contrasts for the BOTH (ADHD+ASD) group vs COMP:
# e3 tests the alpha3 difference, e4 the alpha2 difference (weights obtained
# by swapping diagnosis1/diagnosis2 relative to the ADHD contrasts above)
e3 = hypothesis(m.alpha, "0 < -(2*diagnosis2 + diagnosis1) +
2*diagnosis2:level1 + diagnosis1:level1", alpha = 0.025)
e3$hypothesis
## Hypothesis
## 1 (0)-(-(2*diagnosis2+diagnosis1)+2*diagnosis2:level1+diagnosis1:level1) < 0
## Estimate Est.Error CI.Lower CI.Upper Evid.Ratio Post.Prob Star
## 1 -0.323644 0.3870926 -1.08834 0.4232842 3.875076 0.794875
e4 = hypothesis(m.alpha, "0 > -(2*diagnosis2 + diagnosis1 +
2*diagnosis2:level1 + diagnosis1:level1)", alpha = 0.025)
e4$hypothesis
## Hypothesis
## 1 (0)-(-(2*diagnosis2+diagnosis1+2*diagnosis2:level1+diagnosis1:level1)) > 0
## Estimate Est.Error CI.Lower CI.Upper Evid.Ratio Post.Prob Star
## 1 0.4151504 0.5398817 -0.6248419 1.475844 3.535147 0.7795
# calculate effect sizes
# standardise the H4c contrast per posterior draw: square every sd/sigma draw
# (variances), sum them row-wise and take the square root to get the total
# posterior SD, then divide the contrast by it (a Cohen's-d-like measure)
df.effect = post.draws %>%
mutate(across(starts_with("sd")|starts_with("sigma"), ~.^2)) %>%
mutate(
sumvar = sqrt(rowSums(select(., starts_with("sd")|starts_with("sigma")))),
h4c = (2*`b_diagnosis1` + `b_diagnosis2`) / sumvar
)
# summarise the effect-size posteriors: credible bounds around the mean plus
# a verbal interpretation of the mean magnitude (Cohen's d conventions).
# NOTE(review): lower_ci()/upper_ci() are helpers defined elsewhere in this
# file — presumably 95% interval bounds; confirm against their definition.
kable(df.effect %>% select(starts_with("e")|starts_with("h")) %>%
pivot_longer(cols = everything(), values_to = "estimate") %>%
group_by(name) %>%
summarise(
ci.lo = lower_ci(estimate),
mean = mean(estimate),
ci.hi = upper_ci(estimate),
interpret = interpret_cohens_d(mean)
), digits = 3
)
| name | ci.lo | mean | ci.hi | interpret |
|---|---|---|---|---|
| h4c | -0.244 | 0.168 | 0.59 | very small |
h4c ADHD vs. COMP: estimate = -0.29 [-1.01, 0.42], posterior probability = 78.69%
# rank transform the values (non-parametric follow-up on the absolute
# learning-rate updates)
df.upd = df.upd %>% ungroup() %>%
  mutate(rvalue = rank(value))
# fit the Bayes-factor ANOVA over all fixed-effect model variants, or reload
# a cached fit if one exists
if (!file.exists(file.path(brms_dir, "aov_alpha.rds"))) {
  aov = anovaBF(rvalue ~ diagnosis * level * change, data = df.upd)
  # cache the fit — without this the else-branch could never find a file to
  # read and the model would be refitted on every knit
  saveRDS(aov, file.path(brms_dir, "aov_alpha.rds"))
} else {
  aov = readRDS(file.path(brms_dir, "aov_alpha.rds"))
}
# rank the models by (log) Bayes factor and interpret the drop to the
# next-best model; bf.diff is NA for the last (worst) model
kable(aov@bayesFactor %>% arrange(desc(bf)) %>%
  select(bf) %>% mutate(bf.diff = abs(lead(bf)-bf),
                        bf.int = interpret_bf(bf.diff, log = TRUE)), digits = 3)
| | bf | bf.diff | bf.int |
|---|---|---|---|
| level + change | 24.839 | 0.728 | anecdotal evidence in favour of |
| level + change + level:change | 24.110 | 1.606 | moderate evidence in favour of |
| change | 22.504 | 0.484 | anecdotal evidence in favour of |
| diagnosis + level + change | 22.020 | 0.614 | anecdotal evidence in favour of |
| diagnosis + level + change + level:change | 21.407 | 0.825 | anecdotal evidence in favour of |
| diagnosis + level + diagnosis:level + change | 20.582 | 0.729 | anecdotal evidence in favour of |
| diagnosis + level + diagnosis:level + change + level:change | 19.852 | 0.135 | anecdotal evidence in favour of |
| diagnosis + change | 19.717 | 0.123 | anecdotal evidence in favour of |
| diagnosis + level + change + diagnosis:change | 19.595 | 0.648 | anecdotal evidence in favour of |
| diagnosis + level + change + diagnosis:change + level:change | 18.946 | 0.837 | anecdotal evidence in favour of |
| diagnosis + level + diagnosis:level + change + diagnosis:change | 18.109 | 0.646 | anecdotal evidence in favour of |
| diagnosis + level + diagnosis:level + change + diagnosis:change + level:change | 17.464 | 0.148 | anecdotal evidence in favour of |
| diagnosis + change + diagnosis:change | 17.315 | 2.001 | moderate evidence in favour of |
| diagnosis + level + diagnosis:level + change + diagnosis:change + level:change + diagnosis:level:change | 15.314 | 13.743 | extreme evidence in favour of |
| level | 1.571 | 2.836 | strong evidence in favour of |
| diagnosis + level | -1.265 | 1.575 | moderate evidence in favour of |
| diagnosis | -2.841 | 0.168 | anecdotal evidence in favour of |
| diagnosis + level + diagnosis:level | -3.008 | NA | NA |
# rain cloud plot: raw points, box plot and half-violin of the absolute
# learning-rate updates per diagnosis group, faceted by level x phase change
df.upd %>%
  ggplot(aes(x = 1, y = value, colour = diagnosis, fill = diagnosis)) +
  geom_rain(
    rain.side = 'r',
    # raw data points, nudged left of the box plots
    point.args = list(show.legend = FALSE, alpha = .5),
    point.args.pos = list(position = ggpp::position_dodgenudge(x = -0.25, width = 0.1)),
    # box plots centred on the axis position
    boxplot.args = list(color = "black", outlier.shape = NA, show.legend = FALSE, alpha = .8),
    boxplot.args.pos = list(width = 0.3, position = ggpp::position_dodgenudge(x = 0, width = 0.3)),
    # half violins nudged to the right
    violin.args = list(color = "black", outlier.shape = NA, alpha = .8),
    violin.args.pos = list(width = 0.6, position = position_nudge(x = 0.16))
  ) +
  facet_wrap(level ~ change, scales = "free") +
  labs(title = "Learning rate updates", x = "", y = "") +
  scale_colour_manual(values = col.grp) +
  scale_fill_manual(values = col.grp) +
  theme_bw() +
  theme(
    legend.position = "bottom", legend.direction = "horizontal",
    plot.title = element_text(hjust = 0.5), text = element_text(size = 15),
    axis.text.x = element_blank(), axis.ticks.x = element_blank()
  )
# Excluding the outliers
# same rain cloud plot as above, but restricted to updates below the 0.4
# cut-off so the facets are not dominated by a handful of extreme values
df.upd %>%
filter(value < 0.4) %>%
ggplot(aes(1, value, fill = diagnosis, colour = diagnosis)) + #
geom_rain(rain.side = 'r',
boxplot.args = list(color = "black", outlier.shape = NA, show.legend = FALSE, alpha = .8),
violin.args = list(color = "black", outlier.shape = NA, alpha = .8),
boxplot.args.pos = list(
position = ggpp::position_dodgenudge(x = 0, width = 0.3), width = 0.3
),
point.args = list(show.legend = FALSE, alpha = .5),
violin.args.pos = list(
width = 0.6, position = position_nudge(x = 0.16)),
point.args.pos = list(position = ggpp::position_dodgenudge(x = -0.25, width = 0.1))) +
scale_fill_manual(values = col.grp) +
scale_color_manual(values = col.grp) +
facet_wrap(level ~ change, scales = "free") +
labs(title = "Learning rate updates", x = "", y = "") +
theme_bw() +
theme(legend.position = "bottom", plot.title = element_text(hjust = 0.5),
legend.direction = "horizontal", text = element_text(size = 15),
axis.text.x = element_blank(), axis.ticks.x = element_blank())
# count how many subjects per diagnosis group fall at or above the 0.4
# outlier cut-off used in the plot above
df.upd %>% filter(value >= 0.4) %>% group_by(diagnosis) %>% count()
## # A tibble: 3 × 2
## # Groups: diagnosis [3]
## diagnosis n
## <fct> <int>
## 1 ADHD 4
## 2 BOTH 3
## 3 COMP 1
# including medication
# same rain cloud plot, but distinguishing medicated ADHD subjects: merge in
# the medication indicator and build a combined grouping label from diagnosis
# plus medication status
df.upd %>%
  merge(., df.hgf %>% select(subID, adhd.meds.bin)) %>%
  mutate(
    # as.character() keeps if_else() type-stable: diagnosis is a factor while
    # the TRUE branch is a character string
    diagnosis = if_else(diagnosis == "BOTH", "ADHD+ASD", as.character(diagnosis)),
    # anything other than "TRUE" (incl. NA) gets an empty suffix
    adhd.meds.bin = case_when(adhd.meds.bin == "TRUE" ~ "medicated",
                              TRUE ~ ""),  # TRUE, not the reassignable alias T
    group = paste0(diagnosis, adhd.meds.bin)
  ) %>%
  ggplot(aes(diagnosis, value, fill = group, colour = group)) + #
  geom_rain(rain.side = 'r',
            boxplot.args = list(color = "black", outlier.shape = NA, show.legend = FALSE, alpha = .8),
            violin.args = list(color = "black", outlier.shape = NA, alpha = .8),
            boxplot.args.pos = list(
              position = ggpp::position_dodgenudge(x = 0, width = 0.3), width = 0.3
            ),
            point.args = list(show.legend = FALSE, alpha = .5),
            violin.args.pos = list(
              width = 0.6, position = position_nudge(x = 0.16)),
            point.args.pos = list(position = ggpp::position_dodgenudge(x = -0.25, width = 0.1))) +
  # manual group colours disabled here: more groups than col.grp provides
  # scale_fill_manual(values = col.grp) +
  # scale_color_manual(values = col.grp) +
  facet_wrap(level ~ change, scales = "free") +
  labs(title = "Learning rate updates", x = "", y = "") +
  theme_bw() +
  theme(legend.position = "bottom", plot.title = element_text(hjust = 0.5),
        legend.direction = "horizontal", text = element_text(size = 15),
        axis.text.x = element_blank(), axis.ticks.x = element_blank())